[None][feat] Implement advanced sampling for one model path mtp/eagle #6245
base: main
Conversation
📝 Walkthrough

This change introduces advanced multi-token prediction (MTP) sampling support in the PyTorch execution engine, enabling per-token sampling parameter control for speculative decoding. It adds new fields and methods to propagate and utilize sampling parameters (temperature, top-k, top-p, min-p) throughout the model engine, speculative metadata, and MTP worker. A new batch sampling function and corresponding tests are also included.
Sequence Diagram(s)
sequenceDiagram
participant User
participant PyTorchModelEngine
participant MTPWorker
participant SpecMetadata
participant Sampler
User->>PyTorchModelEngine: forward(requests, ...)
PyTorchModelEngine->>PyTorchModelEngine: _prepare_tp_inputs()
PyTorchModelEngine->>SpecMetadata: update_advanced_mtp_sampling_params(...)
PyTorchModelEngine->>SpecMetadata: _set_up_advanced_mtp_sampling(...)
PyTorchModelEngine->>MTPWorker: sample_and_accept_draft_tokens(input_ids, logits, spec_metadata, ...)
alt enable_mixed_sampler
MTPWorker->>Sampler: sampling_batch(logits, temperatures, top_k, top_p, min_p)
Sampler-->>MTPWorker: sampled_tokens, log_probs
else
MTPWorker->>MTPWorker: greedy_sample(logits)
end
MTPWorker-->>PyTorchModelEngine: accepted_tokens
PyTorchModelEngine-->>User: output
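For orientation, enabling the feature from user code would look roughly like the sketch below. This is a hedged illustration based on the flag names described in this walkthrough; the exact constructor arguments may differ.

```python
from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig

# use_advanced_mtp_sampler=True switches the MTP worker from greedy decoding to
# the per-token temperature/top-k/top-p/min-p path described above.
spec_config = MTPDecodingConfig(
    num_nextn_predict_layers=1,
    use_advanced_mtp_sampler=True,
)
```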
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Actionable comments posted: 1
🧹 Nitpick comments (3)
tests/unittest/_torch/speculative/test_mtp.py (1)
370-370: Fix line length violation. The line exceeds the 120-character limit as flagged by static analysis.
- # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+ # sampling default config vals set in
+ # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]

tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
264-267: Verify the global impact of setting torch.manual_seed(0). Setting a global PyTorch seed in the constructor could have unintended side effects on other operations. Consider:
- This affects all PyTorch random operations, not just sampling
- It might interfere with user-controlled randomness
- Consider using a local generator instead of global seed
Consider using a dedicated random generator for sampling operations:
- # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
- # operations that avoid torch.multinomial's CPU-GPU sync overhead
- torch.manual_seed(0)
+ # Create dedicated generator for consistent multi-GPU sampling
+ # to avoid torch.multinomial's CPU-GPU sync overhead
+ self.sampling_generator = torch.Generator(device='cuda')
+ self.sampling_generator.manual_seed(0)

Then pass this generator to sampling operations that need deterministic behavior.
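If a dedicated generator were introduced as suggested, the sampling helper could accept it explicitly. A minimal sketch (the attribute and function names here are illustrative, not the PR's code):

```python
import torch

def random_sample(probs: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
    # Same exponential-race sampling as in the PR's sampler, but drawing from the
    # dedicated generator instead of the global PyTorch RNG state.
    q = torch.empty_like(probs).exponential_(generator=generator)
    return probs.div(q).argmax(dim=-1)

# e.g. in the engine constructor (hypothetical attribute name):
sampling_generator = torch.Generator(device="cuda")
sampling_generator.manual_seed(0)
```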
1163-1195: Consider moving helper functions to class level for better organization. These helper functions are defined inside _prepare_tp_inputs but could be reused elsewhere. Consider making them class methods or static methods. Move these functions to class level:
- def get_request_temperature(request: LlmRequest) -> float:
-     if not request.sampling_config.temperature:
-         return 0.7
-     temperature = request.sampling_config.temperature[0]
-     if 0 < temperature < 1e-2:
-         # temperature less than 0.01 may cause numerical errors
-         temperature = 0.01
-     return temperature
+ @staticmethod
+ def _get_request_temperature(request: LlmRequest) -> float:
+     if not request.sampling_config.temperature:
+         return 0.7
+     temperature = request.sampling_config.temperature[0]
+     if 0 < temperature < 1e-2:
+         # temperature less than 0.01 may cause numerical errors
+         temperature = 0.01
+     return temperature

Apply similar changes to the other helper functions and update the call sites accordingly.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
- tensorrt_llm/_torch/speculative/interface.py (1 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
- tensorrt_llm/llmapi/llm_args.py (1 hunks)
- tensorrt_llm/llmapi/tokenizer.py (1 hunks)
- tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🧠 Learnings (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.374Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache() and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py
370-370: Line too long (123 > 120)
(E501)
🔇 Additional comments (20)
tensorrt_llm/_torch/speculative/interface.py (1)
135-142: LGTM! Clean addition of sampling parameter fields. The new optional tensor fields for sampling parameters (temperatures, top_k, top_p, min_p) are well-structured and follow the existing pattern in the SpecMetadata dataclass. The type annotations and comments are clear and appropriate.

examples/llm-api/quickstart_advanced.py (2)
115-117: LGTM! Clean addition of command-line argument. The new --use_advanced_mtp_sampler flag follows the established pattern for boolean command-line arguments with an appropriate default value.
169-170: LGTM! Proper integration of the new flag. The use_advanced_mtp_sampler parameter is correctly passed to the MTPDecodingConfig constructor, maintaining consistency with the command-line argument.

tensorrt_llm/_torch/speculative/mtp.py (2)
11-11: LGTM! Appropriate import addition. The import of the sampling_batch function is correctly added to support the advanced MTP sampler functionality.
825-833: LGTM! Well-structured conditional sampling logic. The implementation demonstrates good practices:
- Backward compatibility: Maintains the existing greedy sampling as the default behavior
- Clear conditional logic: The flag-based switching is easy to understand and maintain
- Future-proofing: Acknowledges the unused target_log_probs for future log probability support
- Clean integration: The advanced sampler integrates seamlessly with the existing acceptance algorithm
The approach minimizes risk while enabling the new advanced sampling functionality.
tensorrt_llm/llmapi/llm_args.py (1)
417-422: LGTM! Configuration improvements enhance usability. The changes improve the MTPDecodingConfig class by:
- Making several fields optional with sensible conservative defaults
- Adding the new use_advanced_mtp_sampler flag to enable the advanced sampling feature
- Following consistent patterns with other boolean configuration flags
The default values are appropriate:
- num_nextn_predict_layers=1 maintains backward compatibility
- Boolean flags default to False for conservative behavior
- relaxed_topk=1 and relaxed_delta=0. provide safe starting points
This provides a clean API where users can enable advanced sampling by simply setting use_advanced_mtp_sampler=True without having to specify all the other parameters.
333-401: LGTM! Good test coverage for advanced MTP sampler in greedy mode. The test implementation correctly validates the advanced PyTorch sampler functionality with proper setup of sampling parameters to enforce greedy behavior. The deterministic seeding and reuse of existing test cases ensure consistency and reproducibility.
However, note that this test only covers greedy mode (temperature ≤ 0.01). Consider adding future tests for actual advanced sampling modes (temperature > 0.01) to validate the full functionality of the advanced sampler.
tensorrt_llm/_torch/pyexecutor/model_engine.py (5)
20-20: LGTM! The import is necessary for accessing sampling configurations from request objects.
284-286: LGTM! Clear and logical detection of advanced MTP sampler mode.
382-398: LGTM! Appropriate CUDA tensor allocations for sampling parameters with correct sizes and data types.
1229-1234: LGTM! Sampling parameters are correctly collected and replicated for each token position across different request types.
Also applies to: 1317-1326, 1356-1365, 1398-1407
1511-1526: LGTM! Efficient non-blocking CUDA tensor copies and proper assignment to spec_metadata for advanced MTP sampling.
Also applies to: 1601-1607
tensorrt_llm/_torch/pyexecutor/sampler.py (8)
4-4: LGTM: Clean import addition. The additional typing imports are necessary for the new type annotations in the sampling functions.
154-167: LGTM: Well-implemented sampling function. The function correctly implements top-k and top-p filtering with efficient in-place operations. The use of custom random sampling to avoid CPU-GPU synchronization is a good performance optimization.
169-178: LGTM: Clever sampling implementation. This function uses the Gumbel-max trick effectively to avoid CPU-GPU synchronization. The mathematical approach is sound and the performance justification is clear.
180-198: LGTM: Correct min-p implementation. The adaptive probability thresholding logic is correctly implemented, using the standard approach of scaling min_p by the maximum probability per sequence.
200-232: LGTM: Comprehensive top-k/top-p implementation. The function correctly implements both top-k and top-p filtering with proper handling of edge cases like the "at least one" guarantee. The sorting and scatter approach ensures correctness.
234-236: LGTM: Simple and correct greedy sampling. Clean implementation using argmax with proper tensor reshaping.
238-244: LGTM: Efficient temperature scaling. Correct implementation with efficient in-place operations and proper broadcasting.
246-264: LGTM: Well-designed batch sampling function. This function effectively combines all sampling techniques with proper handling of greedy vs. random sampling. The temperature threshold logic and log-probability calculation are correctly implemented.
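For readers following along, here is a self-contained sketch of how the pieces described above (temperature scaling, top-k/top-p masking, min-p filtering, and the sync-free sampling) compose into a per-token batch sampler. It is an illustration under assumed names and shapes, not the PR's actual sampling_batch implementation:

```python
import torch


def sampling_batch_sketch(logits: torch.Tensor,        # [num_tokens, vocab]
                          temperatures: torch.Tensor,  # [num_tokens]
                          top_k: torch.Tensor,         # [num_tokens], int, assumed >= 1
                          top_p: torch.Tensor,         # [num_tokens]
                          min_p: torch.Tensor):        # [num_tokens]
    # Temperature scaling (clamped to avoid dividing by ~0).
    logits = logits / temperatures.clamp_min(1e-2).unsqueeze(-1)

    # Top-k / top-p: sort once, mask the tail, scatter back to vocab order.
    vocab = logits.size(-1)
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    ranks = torch.arange(vocab, device=logits.device).expand_as(sorted_logits)
    topk_mask = ranks >= top_k.unsqueeze(-1)
    sorted_probs = sorted_logits.softmax(dim=-1)
    cum_probs = sorted_probs.cumsum(dim=-1)
    # "At least one" guarantee: the highest-probability token is never masked.
    topp_mask = (cum_probs - sorted_probs) > top_p.unsqueeze(-1)
    sorted_logits = sorted_logits.masked_fill(topk_mask | topp_mask, float("-inf"))
    logits = torch.full_like(logits, float("-inf")).scatter_(-1, sorted_idx, sorted_logits)

    # Min-p: drop tokens whose probability falls below min_p * max probability per row.
    probs = logits.softmax(dim=-1)
    probs = probs.masked_fill(
        probs < min_p.unsqueeze(-1) * probs.amax(dim=-1, keepdim=True), 0.0)

    # Exponential-race ("Gumbel-max" style) sampling: no torch.multinomial, no CPU-GPU sync.
    q = torch.empty_like(probs).exponential_()
    sampled_tokens = probs.div(q).argmax(dim=-1)
    log_probs = torch.log_softmax(logits, dim=-1)
    return sampled_tokens, log_probs
```

The real implementation splits these steps across the separate helpers reviewed above and also takes a greedy argmax path when the temperature is at or below the 0.01 threshold.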
Force-pushed from 4a68f67 to 84d09a0 (compare).
/bot run --disable-fail-fast
Actionable comments posted: 0
🧹 Nitpick comments (3)
tests/unittest/_torch/speculative/test_mtp.py (1)
370-370: Fix line length violation. The comment line exceeds the 120-character limit flagged by static analysis.
- # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+ # sampling default config vals set in
+ # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]

tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
264-266: Consider making the deterministic seed configurable. The hardcoded seed value of 0 ensures consistent multi-GPU sampling, but consider making this configurable through the PyTorchConfig to provide flexibility for different use cases while maintaining the performance benefits of avoiding CPU-GPU synchronization.
- # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
- # operations that avoid torch.multinomial's CPU-GPU sync overhead
- torch.manual_seed(0)
+ # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
+ # operations that avoid torch.multinomial's CPU-GPU sync overhead
+ seed = getattr(pytorch_backend_config, 'sampling_seed', 0)
+ torch.manual_seed(seed)
1163-1194: LGTM: Well-designed helper functions with proper edge case handling. The helper functions correctly extract sampling parameters with appropriate defaults and constraints. The temperature clamping to avoid numerical errors and top_k max value handling are particularly well thought out.
Consider extracting the magic numbers to constants:
+TEMPERATURE_MIN_THRESHOLD = 1e-2
+TEMPERATURE_MIN_VALUE = 0.01
+TOP_K_DISABLED_VALUE = 2147483647  # Max int32

 def get_request_temperature(request: LlmRequest) -> float:
     if not request.sampling_config.temperature:
         return 0.7
     temperature = request.sampling_config.temperature[0]
-    if 0 < temperature < 1e-2:
-        temperature = 0.01
+    if 0 < temperature < TEMPERATURE_MIN_THRESHOLD:
+        temperature = TEMPERATURE_MIN_VALUE
     return temperature
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
- tensorrt_llm/_torch/speculative/interface.py (1 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
- tensorrt_llm/llmapi/llm_args.py (1 hunks)
- tensorrt_llm/llmapi/tokenizer.py (1 hunks)
- tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🧠 Learnings (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.703Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.374Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache() and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py
370-370: Line too long (123 > 120)
(E501)
🚧 Files skipped from review as they are similar to previous changes (6)
- tensorrt_llm/llmapi/tokenizer.py
- tensorrt_llm/_torch/speculative/interface.py
- tensorrt_llm/_torch/speculative/mtp.py
- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/llmapi/llm_args.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (8)
tests/unittest/_torch/speculative/test_mtp.py (2)
333-401: LGTM! Well-structured test for the advanced MTP sampler. The new test method effectively validates that the advanced PyTorch sampler produces identical results to the standard sampler when configured for greedy mode. The test design is solid:
- Proper parameterization reusing existing test cases
- Deterministic seeding for reproducible results
- Correct configuration of sampling parameters to enforce greedy mode (temperature ≤ 0.01)
- Appropriate assertions matching the reference implementation
369-374: Greedy sampling parameters verified. Confirmed that in tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_temperature, any temperature below 0.01 is clamped up to 0.01. Therefore, extending temperatures with 0.01 correctly enforces the intended greedy sampling boundary. No changes required.

tensorrt_llm/_torch/pyexecutor/model_engine.py (6)
20-20: LGTM: Import addition is necessary for new functionality. The LlmRequest import is properly placed and required for accessing sampling configuration in the advanced MTP sampler.
284-285: LGTM: Correct logic for advanced MTP sampler detection. The boolean flag correctly identifies when the advanced MTP sampler should be active by checking all necessary conditions in the proper sequence.
382-398: LGTM: Proper CUDA tensor allocation for sampling parameters. The tensor allocation correctly sizes buffers for batch_size × (max_draft_len + 1) elements, uses appropriate data types, and efficiently allocates only when the advanced sampler is enabled.
1157-1161: LGTM: Correct parameter replication for draft tokens. The sampling parameter lists are properly initialized and populated with the correct replication pattern for each request type, ensuring parameters are available for both the main token and all draft tokens.
Also applies to: 1229-1233, 1318-1326, 1357-1365, 1399-1407
1512-1526: LGTM: Efficient CUDA tensor copying and metadata integration. The implementation uses pinned memory and non-blocking copies for optimal performance, properly slices tensors to match actual usage, and cleanly integrates with the existing speculative decoding metadata structure.
Also applies to: 1602-1607
264-266: Excellent implementation of advanced MTP sampler support. The changes successfully add support for advanced sampling parameters in MTP speculative decoding with:
- Proper memory management through pre-allocated CUDA tensors
- Efficient parameter extraction and replication logic
- Clean integration with existing speculative decoding infrastructure
- Good performance considerations (pinned memory, non-blocking transfers)
- Appropriate conditional activation based on configuration flags
The implementation follows established patterns in the codebase and maintains backward compatibility while adding the new functionality.
Also applies to: 284-285, 382-398, 1157-1607
PR_Github #12593 [ run ] triggered by Bot

Hi @netanel-haber, @nvxuanyuc has addressed the comments in PR 5627, would you mind reviewing this PR again? Thanks.

PR_Github #12593 [ run ] completed with state
Force-pushed from 84d09a0 to 7f31555 (compare).
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/unittest/_torch/speculative/test_mtp.py (1)
370-370: Fix line length violation. The line exceeds the 120-character limit enforced by the linter.
- # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+ # sampling default config vals set in
+ # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
- tensorrt_llm/_torch/speculative/interface.py (1 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
- tensorrt_llm/llmapi/llm_args.py (1 hunks)
- tensorrt_llm/llmapi/tokenizer.py (1 hunks)
- tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py
370-370: Line too long (123 > 120)
(E501)
✅ Files skipped from review due to trivial changes (2)
- tensorrt_llm/_torch/speculative/interface.py
- tensorrt_llm/llmapi/llm_args.py
🚧 Files skipped from review as they are similar to previous changes (5)
- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/_torch/speculative/mtp.py
- tensorrt_llm/llmapi/tokenizer.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tests/unittest/_torch/speculative/test_mtp.py (1)
333-401: LGTM! Well-structured test for the advanced MTP sampler. The test correctly validates the advanced PyTorch sampler functionality by:
- Using deterministic seeding for reproducible results
- Properly configuring sampling parameters to enforce greedy mode (temperature ≤ 0.01)
- Reusing existing test cases to ensure consistent behavior with the default sampler
- Following the established test patterns in the codebase
The sampling parameter configuration looks correct for greedy mode testing.
Force-pushed from 7f31555 to 607dbc5 (compare).
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/unittest/_torch/speculative/test_mtp.py (1)
370-370: Fix line length violation. The line exceeds the 120 character limit. Consider breaking it into multiple lines for better readability.
- # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+ # sampling default config vals set in
+ # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
- tensorrt_llm/_torch/speculative/interface.py (1 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
- tensorrt_llm/llmapi/llm_args.py (1 hunks)
- tensorrt_llm/llmapi/tokenizer.py (1 hunks)
- tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py
370-370: Line too long (123 > 120)
(E501)
✅ Files skipped from review due to trivial changes (2)
- tensorrt_llm/llmapi/llm_args.py
- tensorrt_llm/_torch/speculative/interface.py
🚧 Files skipped from review as they are similar to previous changes (5)
- tensorrt_llm/_torch/speculative/mtp.py
- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/llmapi/tokenizer.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tests/unittest/_torch/speculative/test_mtp.py (2)
333-401: Well-structured test for advanced MTP sampler in greedy mode. The test method is properly implemented with correct parameterization, deterministic seed setting, and appropriate sampling parameter configuration to enforce greedy mode behavior. The test logic follows the established patterns and should effectively validate the advanced sampler functionality.
363-386: Sampling Parameter Threshold Confirmed. The model engine clamps any non-zero temperature below 0.01 up to 0.01 and treats temperatures ≤ 0.01 as greedy mode. Your test's use of temperature = 0.01 correctly hits that boundary. No changes required.
/bot run --disable-fail-fast

PR_Github #12650 [ run ] triggered by Bot

PR_Github #12650 [ run ] completed with state
Force-pushed from 607dbc5 to 2953667 (compare).
/bot run --disable-fail-fast
Actionable comments posted: 0
🧹 Nitpick comments (2)
tests/unittest/_torch/speculative/test_mtp.py (2)
370-370: Fix line length violation. The line exceeds the 120 character limit flagged by the linter.
- # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+ # sampling default config vals set in
+ # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
388-401: Test execution looks correct but consider broader test coverage. The test execution properly validates that the advanced sampler produces the same results as the original implementation in greedy mode, which is the expected behavior.
However, this test only covers greedy mode. Consider adding tests for the actual advanced sampling modes (temperature > 0.01, top-k < max_int, etc.) to fully validate the new functionality.
Would you like me to help generate additional test cases for non-greedy sampling modes to improve coverage of the advanced sampler functionality?
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
- tensorrt_llm/_torch/speculative/interface.py (1 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
- tensorrt_llm/llmapi/llm_args.py (1 hunks)
- tensorrt_llm/llmapi/tokenizer.py (1 hunks)
- tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py
370-370: Line too long (123 > 120)
(E501)
✅ Files skipped from review due to trivial changes (1)
- tensorrt_llm/_torch/speculative/interface.py
🚧 Files skipped from review as they are similar to previous changes (6)
- tensorrt_llm/llmapi/llm_args.py
- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/_torch/speculative/mtp.py
- tensorrt_llm/llmapi/tokenizer.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tests/unittest/_torch/speculative/test_mtp.py (3)
333-340: LGTM! Good test structure and deterministic setup. The test method signature follows the existing pattern and the deterministic seed ensures consistent behavior across runs, which is important for multi-GPU sampling scenarios.
342-346: Correct configuration for advanced sampler testing. The test properly enables the advanced MTP sampler feature through the use_advanced_mtp_sampler=True flag, which is the key differentiator from the original test method.
363-387: Well-implemented parameter setup for greedy sampling mode. The sampling parameters are correctly configured to enforce greedy behavior:
- Temperature set to 0.01 (at the greedy boundary)
- top_k set to max int value (no filtering)
- top_p set to 1.0 (no filtering)
- min_p set to 0.0 (no filtering)
The logic properly accounts for each batch's draft tokens plus one additional token.
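Concretely, the greedy-mode parameter setup described above amounts to something like the following. This is an illustrative reconstruction; the variable names and example values are assumptions, not the test's exact code.

```python
# Two example requests, each with 3 draft tokens (hypothetical values).
draft_lens = [3, 3]

temperatures, top_k, top_p, min_p = [], [], [], []
for draft_len in draft_lens:
    n = draft_len + 1                 # draft tokens plus the target token
    temperatures.extend([0.01] * n)   # at the greedy boundary
    top_k.extend([2**31 - 1] * n)     # max int32 -> no top-k filtering
    top_p.extend([1.0] * n)           # no top-p filtering
    min_p.extend([0.0] * n)           # no min-p filtering
```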
/bot run --disable-fail-fast

PR_Github #13420 [ run ] triggered by Bot

PR_Github #13420 [ run ] completed with state

/bot run

PR_Github #13502 [ run ] triggered by Bot

PR_Github #13502 [ run ] completed with state
# Default to greedy mode. If true, use advanced pytorch sampling strategy.
self.enable_mixed_sampler = False
if self.model_config is not None:
    self.enable_mixed_sampler = self.model_config.enable_mixed_sampler
Nitpick: This could be a @property rather than a copy, to avoid potential consistency issues in the future.
…agle one model path Signed-off-by: Xuanyu Chen <[email protected]>
Force-pushed from 0ee79b3 to bf7910b (compare).
…er logic Signed-off-by: Xuanyu Chen <[email protected]>
/bot run

PR_Github #22067 [ run ] triggered by Bot. Commit:
Signed-off-by: Izzy Putterman <[email protected]>
Filters logits using adaptive probability thresholding.
"""
# Convert logits to probability distribution
probability_values = torch.nn.functional.softmax(logits, dim=-1)
this effectively neutralizes the temperature right? We apply temp then softmax again which undoes the scaling
Or perhaps in sampling_batch_spec_dec_one_model, we should remove the first softmax and put one just before the sort in the apply top_k top_P
nvm im wrong here, misread something
The logits tensor may be updated in-place.
"""
logits = apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
I think we can technically skip this softmax
I think I am also wrong here
I think @IzzyPutterman is right: apply_min_p evaluates the softmax of the temperature-scaled logits and uses that to mask out some of the logits (set to -inf). The probs could be masked in the same way (set to 0). The resulting probs can (mind paragraph below) then be reused in apply_top_k_top_p, which masks out more logits/probs.
Every time logits/probs are masked, it is sufficient to renormalize the probs such that they sum to one, which is much cheaper than computing softmax. This is probably also why https://docs.flashinfer.ai/api/sampling.html uses function names like ..._renorm_probs.
Note that much of this is already worked out in #8581, albeit using flashinfer.sampling.
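A sketch of the renormalization idea (shapes assumed to be [num_tokens, vocab] for probs and [num_tokens] for min_p; this is not the PR's code):

```python
import torch

def apply_min_p_renorm(probs: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    # Mask tokens below min_p * max-probability, then renormalize instead of
    # recomputing softmax; dividing by the new row sums keeps a valid distribution.
    threshold = min_p.unsqueeze(-1) * probs.amax(dim=-1, keepdim=True)
    probs = probs.masked_fill(probs < threshold, 0.0)
    return probs / probs.sum(dim=-1, keepdim=True)
```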
PR_Github #22067 [ run ] completed with state
) -> torch.Tensor:
"""Apply top-k and top-p masks to the logits.
If a top-p is used, this function will sort the logits tensor,
As a perf optimization, should we skip the expensive sorting / softmax / cumsum ops for top_p >=1?
If top_p is 1, we can skip the expensive sorting / softmax / cumsum ops.
In the latest trt llm version, it is already implemented. Please refer to https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/sampling_utils.py#L159-L171.
The skipping is not possible here: in regular decoding the sampling is not captured in a CUDA graph, but this part is. So unless there's a kernel that determines whether to skip or not (like checkAllTopP in cpp/kernels/samplingTopPKernel.cu), there's no way to check with the CPU flag need_top_p.
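For reference, the host-side early exit discussed above (as used on main, outside CUDA graph capture, where a CPU check is allowed) looks roughly like this sketch; the function names are illustrative and not taken from the repository:

```python
import torch

def apply_top_p_with_skip(logits: torch.Tensor, top_p: torch.Tensor) -> torch.Tensor:
    # Host-side check: only valid when this code is NOT captured in a CUDA graph,
    # since materializing the boolean forces a device synchronization.
    if bool((top_p >= 1.0).all()):
        return logits  # nothing to filter; skip sort/softmax/cumsum entirely

    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    probs = sorted_logits.softmax(dim=-1)
    cum = probs.cumsum(dim=-1)
    mask = (cum - probs) > top_p.unsqueeze(-1)  # keeps at least one token per row
    sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
    return torch.full_like(logits, float("-inf")).scatter_(-1, sorted_idx, sorted_logits)
```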
ixlmar left a comment
Partially reviewed
return logits


def apply_top_k_top_p(
Note that we already have
def top_k_top_p_sampling_batch(
and some related functions. We should not duplicate this basic functionality, so let's use the existing functions and extend them as necessary (adding the Tensor-type k, p, temperature, etc.).
Also note that I am working on FlashInfer.sampling based alternatives for those functions. This upcoming PR brings support for Tensor-type k, p, temperature, etc. when FlashInfer is used. If you integrate the improvements made here for the non-FlashInfer case, this could give a quite nice feature set.
Ideally, this PR could (i) improve the existing sampling routines and (ii) use them via
class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):
Cf. #6245 (comment)
Put up #8581 (work in progress!) to give an idea of what to expect.
See also TRTLLM-7723 (and TRTLLM-7152) for scope of ongoing work.
I think I tested FlashInfer with CUDA graphs and it was breaking a bunch. With the generator objects it's quite annoying in TRT-LLM because in warmup we alternate between CUDA graph warmup and non-CUDA graph warmup, which breaks.
Worth a double check ofc, perhaps there is an easy way around it
@ixlmar I think the current implementation of TopK/TopP only allows all the requests to have the same TopK/TopP value rather than individual requests having different values; please correct me if I'm wrong.
The current logic in model_engine.py didn't parse out all the sampling params into GPU tensors for CUDA graphs; this PR enables that.
@IzzyPutterman The idea of #8581 is to allow choosing between the sampling routines we have today in sampling_utils.py and those provided by FlashInfer. Both will be available as implementations of GroupedStrategySampler. SimpleGroupedStrategySampler uses the former sampling routines (non FlashInfer) and is already available in main.
@jhaotingc Correct. This was what I meant in the first comment:
We should not duplicate this basic functionality, so let's use the existing functions and extend them as necessary (adding the Tensor-type k, p, temperature, etc.).
Ideally, this PR could extend SimpleGroupedStrategySampler to allow for Tensor-type k, p, temperature, etc., in the same way as FlashInferGroupedStrategySampler does it for FlashInfer.sampling in #8581. If the GroupedStrategySampler abstraction is not viable (e.g. because the host data structures interfere with CUDA graphs), then I think we should extend top_k_top_p_sampling_batch (and the related functions) directly (promote scalar arguments to accept any broadcastable torch.Tensor) and reuse them here.
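One way to read "promote scalar arguments to accept any broadcastable torch.Tensor" is to keep the existing scalar path working while also allowing a per-token tensor. A sketch under that assumption (illustrative only; the real change would live in the existing sampling_utils helpers):

```python
from typing import Union
import torch

def scale_by_temperature(logits: torch.Tensor,
                         temperature: Union[float, torch.Tensor]) -> torch.Tensor:
    if not isinstance(temperature, torch.Tensor):
        temperature = torch.tensor(temperature, device=logits.device, dtype=logits.dtype)
    # A 0-d tensor broadcasts like the old scalar; a [num_tokens] tensor gives
    # per-token control after unsqueezing the vocab dimension.
    if temperature.dim() == 0:
        return logits / temperature.clamp_min(1e-5)
    return logits / temperature.clamp_min(1e-5).unsqueeze(-1)
```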
@ixlmar @IzzyPutterman would you mind sharing the latest plan regarding this PR? Is the plan to merge this in and then add in more performant kernels? @ixlmar we can give you access to this branch if you'd like to develop on top of this.

@jhaotingc Thanks. As far as I can tell, the code that I think we should be extending and reusing (
min_p = []

# advanced mtp sampling's request preprocessing helper functions
def collect_req_spec_dec_sampling_params(request: LlmRequest,
The resolution of request.sampling_config to sampling strategy has been cleaned up in #8132. See PR description for the intended semantics. The relevant function is
def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
The existing function covers various corner cases already (e.g. temperature=0, top_p=1, etc.) and has extensive unit tests. Consider reusing this function here (perhaps make it "public", i.e., rename to something that does not start with _).
| """ | ||
| q = torch.empty_like(probs) | ||
| q.exponential_() | ||
| return probs.div_(q).argmax(dim=-1).view(-1) |
I have to admit that I am not familiar with this sampling scheme. If you happen to have a literature reference at hand, I would be curious to learn more (perhaps also include a comment stating the name of the method).
BTW, TorchSampler is using torch.multinomial and I did not notice any stream syncs. Code:
next_tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1)
(Disclaimer: I might have well overlooked that torch.multinomial is syncing so far, so I would be curious to hear more on this.)
I think I have found the answer to my first question: this uses the "Gumbel max trick" (CodeRabbit even points that out in its review...), after a variable transformation from log-probabilities to probabilities. Including a corresponding remark in the doc-string might be useful for future readers.
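A quick way to convince oneself: with E_i ~ Exp(1), argmax_i p_i / E_i = argmin_i E_i / p_i, i.e. a race of exponential clocks with rates p_i, whose winner is i with probability p_i / Σ_j p_j. A small Monte Carlo check (independent of the PR's code):

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])

n = 200_000
q = torch.empty(n, probs.numel()).exponential_()   # E_i ~ Exp(1), one row per trial
samples = (probs / q).argmax(dim=-1)               # same form as probs.div_(q).argmax(...)

empirical = torch.bincount(samples, minlength=probs.numel()).float() / n
print(empirical)  # approximately tensor([0.1, 0.2, 0.3, 0.4])
```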
pin_memory=True)
self.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True)

def _set_up_advanced_sampling(self, batch_size: int, max_draft_len: int):
The code changes in mtp.py and eagle3.py look very similar; perhaps something could be reused.
ixlmar left a comment
I will be unavailable until next week and cannot "un-request changes" on this PR, so please feel free to merge without my explicit approval if this is needed to unblock other tasks.
[feat] Implement pytorch sampler for MTP
Description
The default behavior of the MTP pytorch decoder remains greedy sampling. Advanced sampling can be enabled via the enable_mixed_sampler flag in TorchLlmArgs.
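A minimal usage sketch, assuming the LLM API forwards this TorchLlmArgs field as a keyword argument (the model path and keyword wiring here are assumptions, not taken from the PR):

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig

llm = LLM(
    model="<path-or-hf-id-of-an-mtp-capable-model>",   # placeholder
    enable_mixed_sampler=True,                         # opt in to the PyTorch sampling path
    speculative_config=MTPDecodingConfig(num_nextn_predict_layers=1),
)
```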
Test Coverage

- New unit test validating that the advanced sampler matches the default greedy results for temperature <= 1e-2 using the new PyTorch sampler.

GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug (experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

- --reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- --disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
- --disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.
- --skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- --stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- --gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- --test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- --only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- --disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- --add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.
- --post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- --detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- --debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill

kill
Kill all running builds associated with pull request.

skip

skip --comment COMMENT
Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.